{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tool and Function Calling\n",
    "\n",
    "This guide demonstrates how to use SGLang’s **Tool Calling** functionality."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## OpenAI Compatible API"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Launching the Server"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from openai import OpenAI\n",
    "import json\n",
    "from sglang.utils import (\n",
    "    execute_shell_command,\n",
    "    wait_for_server,\n",
    "    terminate_process,\n",
    "    print_highlight,\n",
    ")\n",
    "\n",
    "server_process = execute_shell_command(\n",
    "    \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --tool-call-parser llama3 --port 30333 --host 0.0.0.0\"  # llama3\n",
    ")\n",
    "wait_for_server(\"http://localhost:30333\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that `--tool-call-parser` defines the parser used to interpret responses. Currently supported parsers include:\n",
    "\n",
    "- llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct).\n",
    "- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/\n",
    "Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n",
    "- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct)."
   ]
  },
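  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model-to-parser mapping above can be written as a small lookup. This is a minimal sketch, not part of SGLang: `guess_tool_call_parser` and `PARSER_HINTS` are hypothetical names, and the substring rules are assumptions inferred only from the example model names in the list.\n",
    "\n",
    "```python\n",
    "from typing import Optional\n",
    "\n",
    "# Hypothetical substring rules taken from the example model names above.\n",
    "PARSER_HINTS = [\n",
    "    (\"llama-3.1\", \"llama3\"),\n",
    "    (\"llama-3.2\", \"llama3\"),\n",
    "    (\"mistral\", \"mistral\"),\n",
    "    (\"qwen2.5\", \"qwen25\"),\n",
    "]\n",
    "\n",
    "\n",
    "def guess_tool_call_parser(model_path: str) -> Optional[str]:\n",
    "    \"\"\"Return a --tool-call-parser value for a model path, or None if unknown.\"\"\"\n",
    "    lowered = model_path.lower()\n",
    "    for hint, parser in PARSER_HINTS:\n",
    "        if hint in lowered:\n",
    "            return parser\n",
    "    return None\n",
    "\n",
    "\n",
    "print(guess_tool_call_parser(\"meta-llama/Llama-3.1-8B-Instruct\"))\n",
    "print(guess_tool_call_parser(\"Qwen/Qwen2.5-7B-Instruct\"))\n",
    "```\n",
    "\n",
    "Unknown models return `None`, so the caller has to choose a parser explicitly rather than silently getting the wrong one."
   ]
  },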
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Tools for Function Call\n",
    "Below is a Python snippet that shows how to define a tool as a dictionary. The dictionary includes a tool name, a description, and property defined Parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define tools\n",
    "tools = [\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"get_current_weather\",\n",
    "            \"description\": \"Get the current weather in a given location\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"city\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The city to find the weather for, e.g. 'San Francisco'\",\n",
    "                    },\n",
    "                    \"state\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"the two-letter abbreviation for the state that the city is\"\n",
    "                        \" in, e.g. 'CA' which would mean 'California'\",\n",
    "                    },\n",
    "                    \"unit\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The unit to fetch the temperature in\",\n",
    "                        \"enum\": [\"celsius\", \"fahrenheit\"],\n",
    "                    },\n",
    "                },\n",
    "                \"required\": [\"city\", \"state\", \"unit\"],\n",
    "            },\n",
    "        },\n",
    "    }\n",
    "]"
   ]
  },
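  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the `parameters` entry lists required keys and allowed `enum` values, arguments produced by the model can be checked before a tool is invoked. The following is a minimal sketch under stated assumptions: `validate_arguments` and `weather_parameters` are hypothetical names, and only required keys, unknown keys, and enum values are checked.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "\n",
    "def validate_arguments(arguments_json: str, parameters: dict) -> dict:\n",
    "    \"\"\"Parse a JSON argument string and check it against a tool's parameters.\n",
    "\n",
    "    Only required keys, unknown keys, and enum values are checked; this is\n",
    "    not a full JSON Schema validator.\n",
    "    \"\"\"\n",
    "    args = json.loads(arguments_json)\n",
    "    if not isinstance(args, dict):\n",
    "        raise ValueError(\"arguments must be a JSON object\")\n",
    "    properties = parameters.get(\"properties\", {})\n",
    "    missing = [key for key in parameters.get(\"required\", []) if key not in args]\n",
    "    if missing:\n",
    "        raise ValueError(f\"missing required arguments: {missing}\")\n",
    "    for key, value in args.items():\n",
    "        if key not in properties:\n",
    "            raise ValueError(f\"unexpected argument: {key}\")\n",
    "        allowed = properties[key].get(\"enum\")\n",
    "        if allowed is not None and value not in allowed:\n",
    "            raise ValueError(f\"{key} must be one of {allowed}, got {value!r}\")\n",
    "    return args\n",
    "\n",
    "\n",
    "# Trimmed copy of the weather schema so the sketch runs on its own.\n",
    "weather_parameters = {\n",
    "    \"type\": \"object\",\n",
    "    \"properties\": {\n",
    "        \"city\": {\"type\": \"string\"},\n",
    "        \"state\": {\"type\": \"string\"},\n",
    "        \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n",
    "    },\n",
    "    \"required\": [\"city\", \"state\", \"unit\"],\n",
    "}\n",
    "\n",
    "print(\n",
    "    validate_arguments(\n",
    "        '{\"city\": \"Boston\", \"state\": \"MA\", \"unit\": \"fahrenheit\"}',\n",
    "        weather_parameters,\n",
    "    )\n",
    ")\n",
    "```\n",
    "\n",
    "Raising `ValueError` keeps the check separate from the tool itself; a caller could instead turn the message into a tool result and let the model retry."
   ]
  },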
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Messages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_messages():\n",
    "    return [\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": \"What's the weather like in Boston today? Please respond with the format: Today's weather is :{function call result}\",\n",
    "        }\n",
    "    ]\n",
    "\n",
    "\n",
    "messages = get_messages()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize the Client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize OpenAI-like client\n",
    "client = OpenAI(api_key=\"None\", base_url=\"http://0.0.0.0:30333/v1\")\n",
    "model_name = client.models.list().data[0].id"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###  Non-Streaming Request"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Non-streaming mode test\n",
    "response_non_stream = client.chat.completions.create(\n",
    "    model=model_name,\n",
    "    messages=messages,\n",
    "    temperature=0.8,\n",
    "    top_p=0.8,\n",
    "    stream=False,  # Non-streaming\n",
    "    tools=tools,\n",
    ")\n",
    "print_highlight(\"Non-stream response:\")\n",
    "print(response_non_stream)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Streaming Request"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Streaming mode test\n",
    "print_highlight(\"Streaming response:\")\n",
    "response_stream = client.chat.completions.create(\n",
    "    model=model_name,\n",
    "    messages=messages,\n",
    "    temperature=0.8,\n",
    "    top_p=0.8,\n",
    "    stream=True,  # Enable streaming\n",
    "    tools=tools,\n",
    ")\n",
    "\n",
    "chunks = []\n",
    "for chunk in response_stream:\n",
    "    chunks.append(chunk)\n",
    "    if chunk.choices[0].delta.tool_calls:\n",
    "        print(chunk.choices[0].delta.tool_calls[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### Handle Tool Calls\n",
    "\n",
    "When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Non-Streaming Request**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "name_non_stream = response_non_stream.choices[0].message.tool_calls[0].function.name\n",
    "arguments_non_stream = (\n",
    "    response_non_stream.choices[0].message.tool_calls[0].function.arguments\n",
    ")\n",
    "\n",
    "print_highlight(f\"Final streamed function call name: {name_non_stream}\")\n",
    "print_highlight(f\"Final streamed function call arguments: {arguments_non_stream}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Streaming Request**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse and combine function call arguments\n",
    "arguments = []\n",
    "for chunk in chunks:\n",
    "    choice = chunk.choices[0]\n",
    "    delta = choice.delta\n",
    "    if delta.tool_calls:\n",
    "        tool_call = delta.tool_calls[0]\n",
    "        if tool_call.function.name:\n",
    "            print_highlight(f\"Streamed function call name: {tool_call.function.name}\")\n",
    "\n",
    "        if tool_call.function.arguments:\n",
    "            arguments.append(tool_call.function.arguments)\n",
    "            print(f\"Streamed function call arguments: {tool_call.function.arguments}\")\n",
    "\n",
    "# Combine all fragments into a single JSON string\n",
    "full_arguments = \"\".join(arguments)\n",
    "print_highlight(f\"Final streamed function call arguments: {full_arguments}\")"
   ]
  },
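  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above reads only the first tool call in each chunk. If a response may contain several tool calls, the fragments can be grouped by their `index` field instead. This is a minimal sketch: `merge_tool_call_deltas` and `fake_delta` are hypothetical helpers, the fragments are hand-written stand-ins rather than real server output, and it assumes the server sets `index` on every streamed item.\n",
    "\n",
    "```python\n",
    "from types import SimpleNamespace\n",
    "\n",
    "\n",
    "def merge_tool_call_deltas(deltas):\n",
    "    \"\"\"Merge streamed tool-call fragments into complete calls, keyed by index.\"\"\"\n",
    "    calls = {}\n",
    "    for delta in deltas:\n",
    "        entry = calls.setdefault(delta.index, {\"name\": \"\", \"arguments\": \"\"})\n",
    "        if delta.function.name:\n",
    "            entry[\"name\"] = delta.function.name\n",
    "        if delta.function.arguments:\n",
    "            entry[\"arguments\"] += delta.function.arguments\n",
    "    return [calls[index] for index in sorted(calls)]\n",
    "\n",
    "\n",
    "def fake_delta(index, name=None, arguments=None):\n",
    "    \"\"\"Stand-in for one streamed tool-call item (not real server output).\"\"\"\n",
    "    return SimpleNamespace(\n",
    "        index=index, function=SimpleNamespace(name=name, arguments=arguments)\n",
    "    )\n",
    "\n",
    "\n",
    "fragments = [\n",
    "    fake_delta(0, name=\"get_current_weather\"),\n",
    "    fake_delta(0, arguments='{\"city\": \"Boston\", '),\n",
    "    fake_delta(0, arguments='\"state\": \"MA\", \"unit\": \"fahrenheit\"}'),\n",
    "]\n",
    "print(merge_tool_call_deltas(fragments))\n",
    "```\n",
    "\n",
    "With the `chunks` list collected earlier, the input would be every item of `chunk.choices[0].delta.tool_calls` across all chunks."
   ]
  },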
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define a Tool Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is a demonstration, define real function according to your usage.\n",
    "def get_current_weather(city: str, state: str, unit: \"str\"):\n",
    "    return (\n",
    "        f\"The weather in {city}, {state} is 85 degrees {unit}. It is \"\n",
    "        \"partly cloudly, with highs in the 90's.\"\n",
    "    )\n",
    "\n",
    "\n",
    "available_tools = {\"get_current_weather\": get_current_weather}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## Execute the Tool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "call_data = json.loads(full_arguments)\n",
    "\n",
    "messages.append(\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"\",\n",
    "        \"tool_calls\": {\"name\": \"get_current_weather\", \"arguments\": full_arguments},\n",
    "    }\n",
    ")\n",
    "\n",
    "# Call the corresponding tool function\n",
    "tool_name = messages[-1][\"tool_calls\"][\"name\"]\n",
    "tool_to_call = available_tools[tool_name]\n",
    "result = tool_to_call(**call_data)\n",
    "print_highlight(f\"Function call result: {result}\")\n",
    "messages.append({\"role\": \"tool\", \"content\": result, \"name\": tool_name})\n",
    "\n",
    "print_highlight(f\"Updated message history: {messages}\")"
   ]
  },
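  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above assumes the model names a known tool and produces valid JSON. A more defensive dispatch might look like the sketch below. `run_tool_call`, `demo_weather`, and `demo_registry` are hypothetical names used only for illustration, and returning errors as text is one possible design choice, not SGLang behavior.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "\n",
    "def run_tool_call(name: str, arguments_json: str, registry: dict) -> str:\n",
    "    \"\"\"Look up a tool by name, call it with JSON arguments, and return text.\n",
    "\n",
    "    Failures are returned as text so they can be sent back to the model\n",
    "    as the tool result instead of crashing the loop.\n",
    "    \"\"\"\n",
    "    tool = registry.get(name)\n",
    "    if tool is None:\n",
    "        return f\"Error: unknown tool '{name}'\"\n",
    "    try:\n",
    "        kwargs = json.loads(arguments_json)\n",
    "    except json.JSONDecodeError as exc:\n",
    "        return f\"Error: arguments are not valid JSON ({exc.msg})\"\n",
    "    if not isinstance(kwargs, dict):\n",
    "        return \"Error: arguments must be a JSON object\"\n",
    "    try:\n",
    "        return str(tool(**kwargs))\n",
    "    except TypeError as exc:\n",
    "        return f\"Error: bad arguments for '{name}': {exc}\"\n",
    "\n",
    "\n",
    "def demo_weather(city: str, state: str, unit: str) -> str:\n",
    "    # Stand-in tool with canned output, for illustration only.\n",
    "    return f\"The weather in {city}, {state} is 85 degrees {unit}.\"\n",
    "\n",
    "\n",
    "demo_registry = {\"get_current_weather\": demo_weather}\n",
    "print(\n",
    "    run_tool_call(\n",
    "        \"get_current_weather\",\n",
    "        '{\"city\": \"Boston\", \"state\": \"MA\", \"unit\": \"fahrenheit\"}',\n",
    "        demo_registry,\n",
    "    )\n",
    ")\n",
    "print(run_tool_call(\"get_time\", \"{}\", demo_registry))\n",
    "```\n",
    "\n",
    "The returned string can be placed in the `content` of the `tool` message, so the model sees either the result or the reason the call failed."
   ]
  },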
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Send Results Back to Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "final_response = client.chat.completions.create(\n",
    "    model=model_name,\n",
    "    messages=messages,\n",
    "    temperature=0.8,\n",
    "    top_p=0.8,\n",
    "    stream=False,\n",
    "    tools=tools,\n",
    ")\n",
    "print_highlight(\"Non-stream response:\")\n",
    "print(final_response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Native API and SGLang Runtime (SRT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "import requests\n",
    "\n",
    "# generate an answer\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n",
    "\n",
    "messages = get_messages()\n",
    "\n",
    "input = tokenizer.apply_chat_template(\n",
    "    messages,\n",
    "    tokenize=False,\n",
    "    add_generation_prompt=True,\n",
    "    tools=tools,\n",
    ")\n",
    "\n",
    "gen_url = \"http://localhost:30333/generate\"\n",
    "gen_data = {\"text\": input, \"sampling_params\": {\"skip_special_tokens\": False}}\n",
    "gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n",
    "print(gen_response)\n",
    "\n",
    "# parse the response\n",
    "parse_url = \"http://localhost:30333/function_call\"\n",
    "\n",
    "function_call_input = {\n",
    "    \"text\": gen_response,\n",
    "    \"tool_call_parser\": \"llama3\",\n",
    "    \"tools\": tools,\n",
    "}\n",
    "\n",
    "function_call_response = requests.post(parse_url, json=function_call_input)\n",
    "function_call_response_json = function_call_response.json()\n",
    "print(\"function name: \", function_call_response_json[\"calls\"][0][\"name\"])\n",
    "print(\"function arguments: \", function_call_response_json[\"calls\"][0][\"parameters\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "terminate_process(server_process)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Offline Engine API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sglang as sgl\n",
    "from sglang.srt.function_call_parser import FunctionCallParser\n",
    "from sglang.srt.managers.io_struct import Tool, Function\n",
    "\n",
    "llm = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n",
    "tokenizer = llm.tokenizer_manager.tokenizer\n",
    "input_ids = tokenizer.apply_chat_template(\n",
    "    messages, tokenize=True, add_generation_prompt=True, tools=tools\n",
    ")\n",
    "\n",
    "sampling_params = {\n",
    "    \"max_new_tokens\": 128,\n",
    "    \"temperature\": 0.3,\n",
    "    \"top_p\": 0.95,\n",
    "    \"skip_special_tokens\": False,\n",
    "}\n",
    "\n",
    "# 1) Offline generation\n",
    "result = llm.generate(input_ids=input_ids, sampling_params=sampling_params)\n",
    "generated_text = result[\"text\"]  # Assume there is only one prompt\n",
    "\n",
    "print(\"=== Offline Engine Output Text ===\")\n",
    "print(generated_text)\n",
    "\n",
    "\n",
    "# 2) Parse using FunctionCallParser\n",
    "def convert_dict_to_tool(tool_dict: dict) -> Tool:\n",
    "    function_dict = tool_dict.get(\"function\", {})\n",
    "    return Tool(\n",
    "        type=tool_dict.get(\"type\", \"function\"),\n",
    "        function=Function(\n",
    "            name=function_dict.get(\"name\"),\n",
    "            description=function_dict.get(\"description\"),\n",
    "            parameters=function_dict.get(\"parameters\"),\n",
    "        ),\n",
    "    )\n",
    "\n",
    "\n",
    "tools = [convert_dict_to_tool(raw_tool) for raw_tool in tools]\n",
    "\n",
    "parser = FunctionCallParser(tools=tools, tool_call_parser=\"llama3\")\n",
    "normal_text, calls = parser.parse_non_stream(generated_text)\n",
    "\n",
    "print(\"\\n=== Parsing Result ===\")\n",
    "print(\"Normal text portion:\", normal_text)\n",
    "print(\"Function call portion:\")\n",
    "for call in calls:\n",
    "    # call: ToolCallItem\n",
    "    print(f\"  - tool name: {call.name}\")\n",
    "    print(f\"    parameters: {call.parameters}\")\n",
    "\n",
    "# 3) If needed, perform additional logic on the parsed functions, such as automatically calling the corresponding function to obtain a return value, etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm.shutdown()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How to support a new model?\n",
    "1. Update the TOOLS_TAG_LIST in sglang/srt/function_call_parser.py with the model’s tool tags. Currently supported tags include:\n",
    "```\n",
    "\tTOOLS_TAG_LIST = [\n",
    "\t    “<|plugin|>“,\n",
    "\t    “<function=“,\n",
    "\t    “<tool_call>“,\n",
    "\t    “<|python_tag|>“,\n",
    "\t    “[TOOL_CALLS]”\n",
    "\t]\n",
    "```\n",
    "2. Create a new detector class in sglang/srt/function_call_parser.py that inherits from BaseFormatDetector. The detector should handle the model’s specific function call format. For example:\n",
    "```\n",
    "    class NewModelDetector(BaseFormatDetector):\n",
    "```\n",
    "3. Add the new detector to the MultiFormatParser class that manages all the format detectors."
   ]
  }
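  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate what handling a model-specific format involves, here is a standalone sketch of the parsing logic a detector might contain. It does not inherit from `BaseFormatDetector` and does not reproduce SGLang's interface: the class name, the method names, and the assumed `<tool_call>{...}</tool_call>` JSON layout are all hypothetical and would need to be adapted to the real base class and the model's actual output.\n",
    "\n",
    "```python\n",
    "import json\n",
    "import re\n",
    "\n",
    "\n",
    "class TaggedJsonDetectorSketch:\n",
    "    \"\"\"Standalone sketch: extract <tool_call>{...}</tool_call> blocks from text.\"\"\"\n",
    "\n",
    "    pattern = re.compile(r\"<tool_call>(.*?)</tool_call>\", re.DOTALL)\n",
    "\n",
    "    def has_tool_call(self, text: str) -> bool:\n",
    "        return \"<tool_call>\" in text\n",
    "\n",
    "    def detect_and_parse(self, text: str):\n",
    "        \"\"\"Return (normal_text, calls); each call has a name and JSON parameters.\"\"\"\n",
    "        calls = []\n",
    "        for body in self.pattern.findall(text):\n",
    "            try:\n",
    "                payload = json.loads(body)\n",
    "            except json.JSONDecodeError:\n",
    "                continue  # Skip malformed blocks instead of failing the parse.\n",
    "            if not isinstance(payload, dict):\n",
    "                continue\n",
    "            calls.append(\n",
    "                {\n",
    "                    \"name\": payload.get(\"name\"),\n",
    "                    \"parameters\": json.dumps(payload.get(\"arguments\", {})),\n",
    "                }\n",
    "            )\n",
    "        normal_text = self.pattern.sub(\"\", text).strip()\n",
    "        return normal_text, calls\n",
    "\n",
    "\n",
    "detector = TaggedJsonDetectorSketch()\n",
    "sample = (\n",
    "    'Checking. <tool_call>{\"name\": \"get_current_weather\", '\n",
    "    '\"arguments\": {\"city\": \"Boston\"}}</tool_call>'\n",
    ")\n",
    "print(detector.detect_and_parse(sample))\n",
    "```\n",
    "\n",
    "This sketch covers only complete, non-streamed text; a detector that supports streaming would also need to handle tags and JSON that arrive split across chunks."
   ]
  }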
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
